from typing_extensions import override

from dataclasses import dataclass

from pydantic import BaseModel
from ten_ai_base.asr import AsyncASRBaseExtension
from ten_ai_base.message import (
    ModuleError,
    ModuleErrorCode,
    ModuleType,
)
from ten_runtime import (
    AsyncTenEnv,
    AudioFrame,
)


class DefaultASRConfig(BaseModel):
    """Configuration model for the default ASR extension.

    Parsed from the extension's property JSON with ``model_validate_json``.
    It declares no fields yet; vendor-specific settings belong here as
    pydantic fields.

    Note: this must stay a plain pydantic model. Stacking the stdlib
    ``@dataclass`` decorator on a ``BaseModel`` replaces pydantic's
    generated ``__init__``/``__eq__``/``__repr__`` and breaks validation.
    """


class DefaultASRExtension(AsyncASRBaseExtension):
    """Template ASR extension to be specialized for a concrete vendor.

    Only configuration loading and error reporting are implemented here.
    The vendor hooks (``vendor``, connection management, ``send_audio``,
    ``is_connected`` and ``input_audio_sample_rate``) raise
    ``NotImplementedError`` and must be overridden. ``finalize`` is
    currently a no-op.
    """

    def __init__(self, name: str):
        super().__init__(name)
        # Set by on_init; remains None if the properties fail to load or
        # validate.
        self.config: DefaultASRConfig | None = None

    @override
    def vendor(self) -> str:
        """Return the vendor identifier. Not implemented in this template."""
        raise NotImplementedError("Vendor method not implemented")

    @override
    async def on_init(self, ten_env: AsyncTenEnv) -> None:
        """Load the extension properties into ``self.config``.

        Reads the full property set as JSON and validates it into a
        ``DefaultASRConfig``. Any exception raised while reading or
        validating is passed to ``_handle_error``, which logs it and
        reports a fatal ASR error. In that case the exception is not
        re-raised and ``self.config`` keeps its previous value.
        """
        await super().on_init(ten_env)

        ten_env.log_info("DefaultASRExtension on_init")

        try:
            # Use the env handed to this callback, consistent with the log
            # call above. The read is inside the try so that a failure to
            # fetch properties is reported the same way as a validation
            # failure.
            # NOTE(review): the second tuple element (presumably an error
            # value) is ignored; a failed read is expected to surface as a
            # validation error on the returned JSON. Confirm this.
            config_json, _ = await ten_env.get_property_to_json("")
            self.config = DefaultASRConfig.model_validate_json(config_json)
        except Exception as e:
            await self._handle_error(e)

    @override
    async def start_connection(self) -> None:
        """Open the connection to the ASR service. Not implemented."""
        raise NotImplementedError("start_connection method not implemented")

    @override
    async def send_audio(
        self, frame: AudioFrame, session_id: str | None
    ) -> None:
        """Forward one audio frame to the ASR service. Not implemented."""
        raise NotImplementedError("send_audio method not implemented")

    @override
    async def finalize(self, session_id: str | None) -> None:
        """Finalize recognition for ``session_id``.

        Currently a no-op: unlike the other hooks, it does not raise.
        """
        # NOTE(review): confirm a silent no-op is the intended default.

    @override
    async def stop_connection(self) -> None:
        """Close the connection to the ASR service. Not implemented."""
        raise NotImplementedError("stop_connection method not implemented")

    @override
    def is_connected(self) -> bool:
        """Report whether the ASR service is connected. Not implemented."""
        raise NotImplementedError("is_connected method not implemented")

    def input_audio_sample_rate(self) -> int:
        """Return the expected input audio sample rate. Not implemented."""
        raise NotImplementedError(
            "input_audio_sample_rate method not implemented"
        )

    async def _handle_error(self, error: Exception) -> None:
        """Log ``error`` and report it as a fatal ASR module error.

        Args:
            error: The exception to report. Its ``str()`` becomes the
                message of the emitted ``ModuleError``.
        """
        self.ten_env.log_error(f"Default error: {error}")
        await self.send_asr_error(
            ModuleError(
                module=ModuleType.ASR,
                code=ModuleErrorCode.FATAL_ERROR.value,
                message=str(error),
            ),
        )
